{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "a4b6a348-5155-4665-9616-3776bea40ff0",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "## Deep Learning - Deep Text Classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Environment Setup on databricks"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# install cloudpickle 2.0.0 to add synapse module for usage of horovod\n",
    "%pip install cloudpickle==2.0.0 --force-reinstall --no-deps"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "cd1e438b-4b6e-4d92-8cd4-0c184afe0721",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "import synapse\n",
    "import cloudpickle\n",
    "\n",
    "cloudpickle.register_pickle_by_value(synapse)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "29b27e85-09c0-4e5f-8c58-af3a2bc9d373",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "! horovodrun --check-build"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "205531d2-6c06-49b4-828a-6f207371830b",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "### Read Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import urllib\n",
    "\n",
    "urllib.request.urlretrieve(\n",
    "    \"https://mmlspark.blob.core.windows.net/publicwasb/text_classification/Emotion_classification.csv\",\n",
    "    \"/tmp/Emotion_classification.csv\",\n",
    ")\n",
    "\n",
    "import pandas as pd\n",
    "from pyspark.ml.feature import StringIndexer\n",
    "\n",
    "df = pd.read_csv(\"/tmp/Emotion_classification.csv\")\n",
    "df = spark.createDataFrame(df)\n",
    "\n",
    "indexer = StringIndexer(inputCol=\"Emotion\", outputCol=\"label\")\n",
    "indexer_model = indexer.fit(df)\n",
    "df = indexer_model.transform(df).drop((\"Emotion\"))\n",
    "\n",
    "train_df, test_df = df.randomSplit([0.85, 0.15], seed=1)\n",
    "display(train_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "bc46d0f5-86b6-409d-b6f9-e3deae631d50",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "1f6b513c-606b-4e32-b75e-2baaf19a11d9",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "from horovod.spark.common.store import DBFSLocalStore\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "from synapse.ml.dl import *\n",
    "import uuid\n",
    "\n",
    "checkpoint = \"bert-base-uncased\"\n",
    "run_output_dir = f\"/dbfs/FileStore/test/{checkpoint}/{str(uuid.uuid4())[:8]}\"\n",
    "store = DBFSLocalStore(run_output_dir)\n",
    "\n",
    "epochs = 1\n",
    "\n",
    "callbacks = [ModelCheckpoint(filename=\"{epoch}-{train_loss:.2f}\")]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "9450b5fe-ab0d-4f73-8eb2-f2428ad88b4e",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "deep_text_classifier = DeepTextClassifier(\n",
    "    checkpoint=checkpoint,\n",
    "    store=store,\n",
    "    callbacks=callbacks,\n",
    "    num_classes=6,\n",
    "    batch_size=16,\n",
    "    epochs=epochs,\n",
    "    validation=0.1,\n",
    "    text_col=\"Text\",\n",
    ")\n",
    "\n",
    "deep_text_model = deep_text_classifier.fit(train_df.limit(6000).repartition(50))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "4168b8a6-330e-4a28-949a-16954e1ea757",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "### Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "d6f97c75-b814-4138-a2a8-512bc85a6f65",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.evaluation import MulticlassClassificationEvaluator\n",
    "\n",
    "pred_df = deep_text_model.transform(test_df.limit(500))\n",
    "evaluator = MulticlassClassificationEvaluator(\n",
    "    predictionCol=\"prediction\", labelCol=\"label\", metricName=\"accuracy\"\n",
    ")\n",
    "print(\"Test accuracy:\", evaluator.evaluate(pred_df))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cleanup the output dir for test\n",
    "dbutils.fs.rm(run_output_dir, True)"
   ]
  }
 ],
 "metadata": {
  "application/vnd.databricks.v1+notebook": {
   "dashboards": [],
   "language": "python",
   "notebookMetadata": {
    "pythonIndentUnit": 2
   },
   "notebookName": "DeepLearning - Deep Text Classification",
   "notebookOrigID": 4390929852015145,
   "widgets": {}
  },
  "kernelspec": {
   "display_name": "Python 3.8.5 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.8.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "601a75c4c141f401603984f1538447337114e368c54c4d5b589ea94315afdca2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
